"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


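Compute the flow stability clustering for a given autocovariance integral matrix (non-sparse version).

Every file in ``--datadir`` whose name starts with ``<net_name>_tau_w`` is clustered
``--num_repeat`` times and a summary is pickled to ``--savedir`` as ``clusters_<file>``.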


"""
import sys
import os
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np
from TemporalStability import (Clustering,
                               norm_var_information)
import pickle
import time
from multiprocessing import Pool
from itertools import combinations
from argparse import ArgumentParser
import pandas as pd
from collections import Counter
#raise Exception
import traceback


#%%

ap = ArgumentParser()

# Command-line options as (flag, add_argument keyword arguments), registered
# in this order.
#   * datadir / savedir default to '' and are checked right after parsing.
#   * the two ``--not_compute_*`` flags use ``store_false``: their parsed value
#     is True unless the flag is given.
#   * nnode, node_num and the ``compute``/``not_compute`` flags are parsed here
#     but not read anywhere in the visible part of this script.
_cli_spec = (
    ('--datadir', {'default': '', 'type': str}),
    ('--savedir', {'default': '', 'type': str}),
    ('--ncpu', {'default': 4, 'type': int}),
    ('--nnode', {'default': 1, 'type': int}),
    ('--node_num', {'default': 0, 'type': int}),
    ('--num_repeat', {'default': 50, 'type': int}),
    ('--n_iter_max', {'default': 100, 'type': int}),
    ('--verbose', {'default': 0, 'type': int}),
    ('--clust_verbose', {'default': 0, 'type': int}),
    ('--activity_driven_nm', {'action': 'store_true'}),
    ('--compute_static_clustering', {'action': 'store_true'}),
    ('--not_compute_source_clustering', {'action': 'store_false'}),
    ('--not_compute_sym_clustering', {'action': 'store_false'}),
    ('--net_name', {'default': 'synthtemp_heira_big', 'type': str}),
)
for _flag, _kwargs in _cli_spec:
    ap.add_argument(_flag, **_kwargs)


# Parsed options as a plain dict, plus one module-level name per option that
# the rest of the script reads.
inargs = vars(ap.parse_args())
(datadir, savedir, ncpu, nnode, node_num, num_repeat, n_iter_max,
 verbose, clust_verbose, activity_driven_nm, net_name) = (
    inargs[key] for key in ('datadir', 'savedir', 'ncpu', 'nnode',
                            'node_num', 'num_repeat', 'n_iter_max',
                            'verbose', 'clust_verbose',
                            'activity_driven_nm', 'net_name'))

#%%
# Both directories are mandatory; their command-line default is the empty
# string, which is treated as "not given".
if datadir == '':
    raise Exception('datadir must be given')

if savedir == '':
    raise Exception('savedir must be given')

# Each worker only writes its result after the whole clustering of a file is
# done, so a missing output directory would otherwise be discovered only after
# all the computation has been spent. Create it up front instead.
os.makedirs(savedir, exist_ok=True)


#%%

# Input files to process: every entry of datadir whose name starts with
# '<net_name>_tau_w'.
files = [f for f in os.listdir(datadir) if f.startswith(net_name + '_tau_w')]


#%% clustering

def _dump_pickle_atomically(obj, path):
    """Pickle ``obj`` to ``path`` without exposing a partially written file.

    The data is written to a process-specific temporary file next to ``path``
    and then moved into place with ``os.replace``. ``worker`` skips every
    input whose result file already exists, so a truncated result left behind
    by an interrupted run would otherwise never be recomputed.
    """
    tmp_path = '{0}.tmp{1}'.format(path, os.getpid())
    try:
        with open(tmp_path, 'wb') as fopen:
            pickle.dump(obj, fopen)
        os.replace(tmp_path, path)
    except BaseException:
        # do not leave the temporary file behind if anything went wrong
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def worker(file_args):
    """Cluster one input file ``num_repeat`` times and pickle a summary.

    Parameters
    ----------
    file_args : str
        Name of a file inside ``datadir``. It is loaded with
        ``pd.read_pickle`` and passed to ``Clustering`` as ``S``.

    Returns
    -------
    None
        The results are written to ``<savedir>/clusters_<file>`` as a pickled
        dict with the keys ``num_repeat``, ``clust_counter_sym``,
        ``sym_stabilites``, ``nvarinf_sym``, ``avg_stab_sym``,
        ``avg_nclust_sym``, ``best_cluster_sym`` and ``best_stab_sym``.

    Notes
    -----
    If the result file already exists the input is skipped. Any ``Exception``
    raised while processing is reported on stdout and stderr (with the
    traceback on stderr) and not re-raised, so a failing file does not stop
    the other files handled by the pool.

    Reads the module-level settings ``datadir``, ``savedir``, ``num_repeat``,
    ``n_iter_max``, ``verbose`` and ``clust_verbose``.
    """
    file = file_args
    pid = os.getpid()
    # kept separate from the clustering timer below so that the final
    # "finished in" message covers the whole call, including loading the input
    t_start = time.time()

    savefile = os.path.join(savedir, 'clusters_' + file)

    if os.path.exists(savefile):
        print('PID ', pid, 'file already exists, skipping', savefile)
    else:
        print('PID ', pid, ' starting file ', file)
        try:
            # S is the autocov integral I
            # NOTE(review): read_pickle unpickles arbitrary objects, so
            # datadir must only contain trusted files.
            S = pd.read_pickle(os.path.join(datadir, file))

            print('PID ', pid, 'computing symmetric clusters ', ', file ', file)

            sym_clusters = []
            sym_stabilites = []

            t0 = time.time()

            for i in range(num_repeat):
                if verbose:
                    print('**** PID ', pid, 'sym ', i)

                # only S is supplied; p1, p2 and T are explicitly left unset
                clustering = Clustering(p1=None, p2=None, T=None, S=S)

                clustering.find_louvain_clustering(n_iter_max=n_iter_max,
                                                   verbose=clust_verbose)

                sym_clusters.append(clustering.partition.cluster_list)
                sym_stabilites.append(clustering.compute_stability())

            t1 = time.time()

            print('PID ', pid, ' symm clust, took ', t1-t0, ', file ', file)
            print('PID ', pid, ' computing sym nvi ', ', file ', file)

            # mean of norm_var_information over all pairs of runs
            nvarinf_sym = np.mean([norm_var_information(c1, c2)
                                   for c1, c2 in combinations(sym_clusters, 2)])

            # count how often each clustering was found; sorting the members
            # of every cluster and then the clusters themselves makes the key
            # independent of the order in which they are listed
            clust_counter_sym = Counter(
                tuple(sorted(tuple(sorted(c)) for c in clust))
                for clust in sym_clusters)

            best_cluster_sym = sym_clusters[np.argmax(sym_stabilites)]
            best_stab_sym = max(sym_stabilites)
            avg_stab_sym = np.mean(sym_stabilites)
            avg_nclust_sym = np.mean([len(c) for c in sym_clusters])

            t2 = time.time()
            print('PID ', pid, ' symm nvi, took ', t2-t1, ', file ', file)

            print('PID ', pid, ' saving to file', savefile )

            res = {'num_repeat' : num_repeat,
                   'clust_counter_sym' : clust_counter_sym,
                   'sym_stabilites' : sym_stabilites,
                   'nvarinf_sym' : nvarinf_sym,
                   'avg_stab_sym' : avg_stab_sym,
                   'avg_nclust_sym' : avg_nclust_sym,
                   'best_cluster_sym' : best_cluster_sym,
                   'best_stab_sym' : best_stab_sym}

            _dump_pickle_atomically(res, savefile)

        except Exception:
            # best effort: report the failure on both streams and carry on
            print('PID ', pid, '-+-+-+ Exception at file:', file,
                  file=sys.stdout)
            print('PID ', pid, '-+-+-+ Exception at file:', file,
                   file=sys.stderr)
            traceback.print_exc(file=sys.stderr)

    print('+++ PID ', pid, 'finished in ', time.time()-t_start)
#%%
# Run the pool only when executed as a script. With the 'spawn' and
# 'forkserver' start methods every child process re-imports this module;
# without the guard each child would try to start a pool of its own.
if __name__ == '__main__':
    t00 = time.time()
    print('starting pool of {0} cpus'.format(ncpu))
    with Pool(ncpu) as p:
        # one task per input file; get() blocks until all of them are done
        work = p.map_async(worker, files)
        data = work.get()

    print('***** Finished! in {0}'.format(time.time()-t00))
    
